﻿#pragma kernel SolveDiffuseContacts

#include "GridUtils.cginc"
#include "CollisionMaterial.cginc"
#include "ContactHandling.cginc"
#include "ColliderDefinitions.cginc"
#include "Rigidbody.cginc"
#include "Simplex.cginc"
#include "MathUtils.cginc"
#include "Bounds.cginc"
#include "SolverParameters.cginc"
#include "Optimization.cginc"
#include "DistanceFunctions.cginc"

// Capacity of the per-thread candidate collider list built by SolveDiffuseContacts.
// Candidates found once the list is full are dropped.
#define MAX_CONTACTS_PER_DIFFUSE 32

StructuredBuffer<float4> inputPositions; // per-particle position at the start of the step (read-only).
StructuredBuffer<float4> inputAttributes; // per-particle attributes; .z times radiusScale is used as the collision radius.
RWStructuredBuffer<float4> inputVelocities; // per-particle velocity: read to predict the position, .xyz overwritten with the corrected velocity.

// per-collider data, all indexed by collider index:
StructuredBuffer<aabb> aabbs;           // collider bounds, tested against the particle's world-space bounds.
StructuredBuffer<transform> transforms; // collider transform; composed with worldToSolver[0] to get collider-to-solver.
StructuredBuffer<shape> shapes;         // shape descriptor (type, dataIndex, contactOffset, center, size...).

// collider grid: cellOffsets/cellCounts are indexed with the value returned by GridHash() and
// delimit a range of sortedColliderIndices. Each entry divided by cellsPerCollider yields a collider index.
StructuredBuffer<uint> sortedColliderIndices;
StructuredBuffer<uint> cellOffsets;
StructuredBuffer<uint> cellCounts;

// triangle mesh data:
StructuredBuffer<TriangleMeshHeader> triangleMeshHeaders;
StructuredBuffer<BIHNode> bihNodes;
StructuredBuffer<Triangle> triangles;
StructuredBuffer<float3> vertices;

// edge mesh data:
StructuredBuffer<EdgeMeshHeader> edgeMeshHeaders;
StructuredBuffer<BIHNode> edgeBihNodes;
StructuredBuffer<Edge> edges;
StructuredBuffer<float2> edgeVertices;

// heightfield data:
StructuredBuffer<HeightFieldHeader> heightFieldHeaders;
StructuredBuffer<float> heightFieldSamples;

// distance field data:
StructuredBuffer<DistanceFieldHeader> distanceFieldHeaders;
StructuredBuffer<DFNode> dfNodes;

// solver <-> world transforms; only element 0 is read by this file.
StructuredBuffer<transform> solverToWorld;
StructuredBuffer<transform> worldToSolver;

// dispatch[3] is the number of particles to process; threads with an index at or beyond it exit immediately.
StructuredBuffer<uint> dispatch;

float radiusScale;     // multiplier applied to the particle radius stored in inputAttributes.z.
uint colliderCount;    // amount of colliders in the grid.
uint cellsPerCollider; // max amount of cells a collider can be inserted into. Typically this is 8.
int shapeTypeCount;    // number of different collider shapes, ie: box, sphere, sdf, etc.
float deltaTime;       // time step used both to predict the particle position and to derive the corrected velocity.


// Resolves penetration of one diffuse particle against a triangle mesh collider.
//
//  colliderIndex  - index of the collider in the shapes / transforms buffers.
//  threadIndex    - index of the particle being processed (currently unused).
//  particleBounds - bounds of the particle's swept path, expressed in WORLD space.
//  pos            - predicted particle position in solver space; projected out of every
//                   triangle it penetrates, in traversal order.
//  radius         - particle collision radius.
//
// Does nothing when the shape has no mesh data (dataIndex < 0).
void CollideMesh(int colliderIndex, int threadIndex, aabb particleBounds, inout float4 pos, float radius)
{
    shape s = shapes[colliderIndex];
    if (s.dataIndex < 0) return;

    TriangleMeshHeader header = triangleMeshHeaders[s.dataIndex];

    // transforms[] is composed with worldToSolver below, i.e. it maps collider space to world space.
    transform colliderToWorld = transforms[colliderIndex];

    TriangleMesh meshShape;
    meshShape.colliderToSolver = worldToSolver[0].Multiply(colliderToWorld);
    meshShape.s = s;

    // The incoming bounds are in world space, so they must be brought to collider space with the
    // inverse of the collider-to-world transform (not collider-to-solver, which is only correct
    // when the solver transform is identity).
    // A full matrix is inverted here to accurately represent collider bounds scale.
    float4x4 worldToCollider = Inverse(TRS(colliderToWorld.translation.xyz, colliderToWorld.rotation, colliderToWorld.scale.xyz));
    aabb simplexBound = particleBounds.Transformed(worldToCollider);

    // contact margin expressed in collider space:
    float4 marginCS = float4((s.contactOffset + collisionMargin) / meshShape.colliderToSolver.scale.xyz, 0);

    float4 radii = float4(radius, radius, radius, 0);

    // NOTE(review): fixed-capacity traversal stack; assumes the BIH is never deep enough
    // to need more than 12 pending nodes - confirm against the BIH builder.
    int stack[12];
    int stackTop = 0;

    // start at the root node:
    stack[stackTop++] = 0;

    while (stackTop > 0)
    {
        // pop node index from the stack:
        int nodeIndex = stack[--stackTop];
        BIHNode node = bihNodes[header.firstNode + nodeIndex];

        // leaf node:
        if (node.firstChild < 0)
        {
            // check for contact against all triangles in the leaf:
            int leafEnd = node.start + node.count;
            for (int dataOffset = node.start; dataOffset < leafEnd; ++dataOffset)
            {
                Triangle t = triangles[header.firstTriangle + dataOffset];
                float4 v1 = float4(vertices[header.firstVertex + t.i1], 0);
                float4 v2 = float4(vertices[header.firstVertex + t.i2], 0);
                float4 v3 = float4(vertices[header.firstVertex + t.i3], 0);

                aabb triangleBounds;
                triangleBounds.FromTriangle(v1, v2, v3, marginCS);

                if (triangleBounds.IntersectsAabb(simplexBound, s.is2D()))
                {
                    // the triangle is cached with the collider scale already applied:
                    meshShape.tri.Cache(v1 * meshShape.colliderToSolver.scale, v2 * meshShape.colliderToSolver.scale, v3 * meshShape.colliderToSolver.scale);

                    SurfacePoint surf;
                    meshShape.Evaluate(pos, radii, QUATERNION_IDENTITY, surf);

                    // signed distance from the particle surface to the triangle; negative means penetration.
                    float dist = dot(pos - surf.pos, surf.normal) - radius;
                    if (dist < 0)
                        pos = surf.pos + surf.normal * radius;
                }
            }
        }
        else // check min and/or max children:
        {
            // visit min node:
            if (simplexBound.min_[node.axis] <= node.min_)
                stack[stackTop++] = node.firstChild;

            // visit max node:
            if (simplexBound.max_[node.axis] >= node.max_)
                stack[stackTop++] = node.firstChild + 1;
        }
    }
}

// Resolves penetration of one diffuse particle against an edge mesh collider.
//
//  colliderIndex  - index of the collider in the shapes / transforms buffers.
//  threadIndex    - index of the particle being processed (currently unused).
//  particleBounds - bounds of the particle's swept path, expressed in WORLD space.
//  pos            - predicted particle position in solver space; projected out of every
//                   edge it penetrates, in traversal order.
//  radius         - particle collision radius.
//
// Does nothing when the shape has no edge mesh data (dataIndex < 0).
void CollideEdgeMesh(int colliderIndex, int threadIndex, aabb particleBounds, inout float4 pos, float radius)
{
    shape s = shapes[colliderIndex];
    if (s.dataIndex < 0) return;

    EdgeMeshHeader header = edgeMeshHeaders[s.dataIndex];

    // transforms[] is composed with worldToSolver below, i.e. it maps collider space to world space.
    transform colliderToWorld = transforms[colliderIndex];

    EdgeMesh meshShape;
    meshShape.colliderToSolver = worldToSolver[0].Multiply(colliderToWorld);
    meshShape.s = s;

    // The incoming bounds are in world space, so they must be brought to collider space with the
    // inverse of the collider-to-world transform (not collider-to-solver, which is only correct
    // when the solver transform is identity).
    // A full matrix is inverted here to accurately represent collider bounds scale.
    float4x4 worldToCollider = Inverse(TRS(colliderToWorld.translation.xyz, colliderToWorld.rotation, colliderToWorld.scale.xyz));
    aabb simplexBound = particleBounds.Transformed(worldToCollider);

    // contact margin expressed in collider space:
    float4 marginCS = float4((s.contactOffset + collisionMargin) / meshShape.colliderToSolver.scale.xyz, 0);

    float4 radii = float4(radius, radius, radius, 0);

    // NOTE(review): fixed-capacity traversal stack; assumes the BIH is never deep enough
    // to need more than 12 pending nodes - confirm against the BIH builder.
    int stack[12];
    int stackTop = 0;

    // start at the root node:
    stack[stackTop++] = 0;

    while (stackTop > 0)
    {
        // pop node index from the stack:
        int nodeIndex = stack[--stackTop];
        BIHNode node = edgeBihNodes[header.firstNode + nodeIndex];

        // leaf node:
        if (node.firstChild < 0)
        {
            // check for contact against all edges in the leaf:
            int leafEnd = node.start + node.count;
            for (int dataOffset = node.start; dataOffset < leafEnd; ++dataOffset)
            {
                Edge t = edges[header.firstEdge + dataOffset];

                // edge vertices are 2D; they are offset by the shape center:
                float4 v1 = float4(edgeVertices[header.firstVertex + t.i1],0,0) + s.center;
                float4 v2 = float4(edgeVertices[header.firstVertex + t.i2],0,0) + s.center;

                aabb edgeBounds;
                edgeBounds.FromEdge(v1, v2, marginCS);

                if (edgeBounds.IntersectsAabb(simplexBound, s.is2D()))
                {
                    // the edge is cached with the collider scale already applied:
                    meshShape.edge.Cache(v1 * meshShape.colliderToSolver.scale, v2 * meshShape.colliderToSolver.scale);

                    SurfacePoint surf;
                    meshShape.Evaluate(pos, radii, QUATERNION_IDENTITY, surf);

                    // signed distance from the particle surface to the edge; negative means penetration.
                    float dist = dot(pos - surf.pos, surf.normal) - radius;
                    if (dist < 0)
                        pos = surf.pos + surf.normal * radius;
                }
            }
        }
        else // check min and/or max children:
        {
            // visit min node:
            if (simplexBound.min_[node.axis] <= node.min_)
                stack[stackTop++] = node.firstChild;

            // visit max node:
            if (simplexBound.max_[node.axis] >= node.max_)
                stack[stackTop++] = node.firstChild + 1;
        }
    }
}

// Resolves penetration of one diffuse particle against a heightfield collider.
//
//  colliderIndex  - index of the collider in the shapes / transforms buffers.
//  threadIndex    - index of the particle being processed (currently unused).
//  particleBounds - bounds of the particle's swept path, expressed in WORLD space.
//  pos            - predicted particle position in solver space; projected out of every
//                   heightfield triangle it penetrates, in cell order (u outer, v inner).
//  radius         - particle collision radius.
//
// The sample grid resolution is read from s.center.xy and the field dimensions from s.size.
// Does nothing when the shape has no heightfield data (dataIndex < 0) or the grid has no cells.
void CollideHeightmap(int colliderIndex, int threadIndex, aabb particleBounds, inout float4 pos, float radius)
{
    shape s = shapes[colliderIndex];
    if (s.dataIndex < 0) return;

    HeightFieldHeader header = heightFieldHeaders[s.dataIndex];

    // transforms[] is composed with worldToSolver below, i.e. it maps collider space to world space.
    transform colliderToWorld = transforms[colliderIndex];

    Heightfield fieldShape;
    fieldShape.colliderToSolver = worldToSolver[0].Multiply(colliderToWorld);
    fieldShape.s = s;

    // The incoming bounds are in world space, so they must be brought to collider space with the
    // inverse of the collider-to-world transform (not collider-to-solver, which is only correct
    // when the solver transform is identity).
    // A full matrix is inverted here to accurately represent collider bounds scale.
    float4x4 worldToCollider = Inverse(TRS(colliderToWorld.translation.xyz, colliderToWorld.rotation, colliderToWorld.scale.xyz));
    aabb simplexBound = particleBounds.Transformed(worldToCollider);

    int resolutionU = (int)s.center.x;
    int resolutionV = (int)s.center.y;

    // a grid with fewer than 2 samples per axis has no cells (and would divide by zero below):
    if (resolutionU < 2 || resolutionV < 2) return;

    // calculate terrain cell size:
    float cellWidth = s.size.x / (resolutionU - 1);
    float cellHeight = s.size.z / (resolutionV - 1);

    // calculate particle bounds min/max cells:
    int2 cellMin = int2((int)floor(simplexBound.min_[0] / cellWidth), (int)floor(simplexBound.min_[2] / cellHeight));
    int2 cellMax = int2((int)floor(simplexBound.max_[0] / cellWidth), (int)floor(simplexBound.max_[2] / cellHeight));

    // only cells in [0, resolution - 2] exist; clamp up front instead of iterating over
    // (and rejecting) every out-of-range cell covered by the bounds.
    int firstU = max(cellMin[0], 0);
    int lastU  = min(cellMax[0], resolutionU - 2);
    int firstV = max(cellMin[1], 0);
    int lastV  = min(cellMax[1], resolutionV - 2);

    float4 radii = float4(radius, radius, radius, 0);

    for (int su = firstU; su <= lastU; ++su)
    {
        for (int sv = firstV; sv <= lastV; ++sv)
        {
            // neighbor sample indices (always in range thanks to the clamping above):
            int csu1 = su + 1;
            int csv1 = sv + 1;

            // sample heights at the four corners of the cell:
            float h1 = heightFieldSamples[header.firstSample + sv * resolutionU + su] * s.size.y;
            float h2 = heightFieldSamples[header.firstSample + sv * resolutionU + csu1] * s.size.y;
            float h3 = heightFieldSamples[header.firstSample + csv1 * resolutionU + su] * s.size.y;
            float h4 = heightFieldSamples[header.firstSample + csv1 * resolutionU + csu1] * s.size.y;

            // a negative height at the cell's first corner skips the whole cell
            // (presumably this is how holes are encoded - confirm against the heightfield exporter).
            if (h1 < 0) continue;
            h1 = abs(h1);
            h2 = abs(h2);
            h3 = abs(h3);
            h4 = abs(h4);

            float min_x = su * s.size.x / (resolutionU - 1);
            float max_x = csu1 * s.size.x / (resolutionU - 1);
            float min_z = sv * s.size.z / (resolutionV - 1);
            float max_z = csv1 * s.size.z / (resolutionV - 1);

            // ------contact against the first triangle------:
            float4 v1 = float4(min_x, h3, max_z, 0);
            float4 v2 = float4(max_x, h4, max_z, 0);
            float4 v3 = float4(min_x, h1, min_z, 0);

            fieldShape.tri.Cache(v1, v2, v3);
            fieldShape.triNormal.xyz = normalizesafe(cross((v2 - v1).xyz, (v3 - v1).xyz));

            SurfacePoint surf;
            fieldShape.Evaluate(pos, radii, QUATERNION_IDENTITY, surf);

            // signed distance from the particle surface to the triangle; negative means penetration.
            float dist = dot(pos - surf.pos, surf.normal) - radius;
            if (dist < 0)
                pos = surf.pos + surf.normal * radius;

            // ------contact against the second triangle------:
            v1 = float4(min_x, h1, min_z, 0);
            v2 = float4(max_x, h4, max_z, 0);
            v3 = float4(max_x, h2, min_z, 0);

            fieldShape.tri.Cache(v1, v2, v3);
            fieldShape.triNormal.xyz = normalizesafe(cross((v2 - v1).xyz, (v3 - v1).xyz));

            fieldShape.Evaluate(pos, radii, QUATERNION_IDENTITY, surf);

            dist = dot(pos - surf.pos, surf.normal) - radius;
            if (dist < 0)
                pos = surf.pos + surf.normal * radius;
        }
    }
}

// Projects `pos` onto the surface described by `surf` (offset by `radius` along the surface
// normal) if a particle of the given radius centered at `pos` penetrates it.
void PushDiffuseParticleOutOfSurface(inout float4 pos, SurfacePoint surf, float radius)
{
    // signed distance from the particle surface to the collider surface; negative means penetration.
    float dist = dot(pos - surf.pos, surf.normal) - radius;
    if (dist < 0)
        pos = surf.pos + surf.normal * radius;
}

// Kernel: one thread per diffuse particle. Predicts the particle position from its velocity,
// gathers candidate colliders from the collider grid, pushes the predicted position out of
// every collider it penetrates, and rewrites the particle velocity so that it reaches the
// corrected position after deltaTime.
[numthreads(128, 1, 1)]
void SolveDiffuseContacts (uint3 id : SV_DispatchThreadID) 
{
    unsigned int threadIndex = id.x;
    if (threadIndex >= dispatch[3]) return;

    int candidateCount = 0;
    uint candidates[MAX_CONTACTS_PER_DIFFUSE];

    float4 startPos = inputPositions[threadIndex];
    float4 predPos = startPos + inputVelocities[threadIndex] * deltaTime;
    float radius = inputAttributes[threadIndex].z * radiusScale;
    float4 radii = float4(radius, radius, radius, 0);

    // max size of the particle bounds in cells:
    int4 maxSize = int4(3,3,3,3);

    // bounds of the particle's swept path, taken to world space for the grid lookup:
    aabb b;
    b.FromEdge(startPos, predPos, radius);
    b.Transform(solverToWorld[0]);

    // build a list of candidate colliders:
    for (uint m = 1; m <= levelPopulation[0]; ++m)
    {
        uint l = levelPopulation[m];
        float cellSize = CellSizeOfLevel(l);

        int4 minCell = floor(b.min_ / cellSize);
        int4 maxCell = floor(b.max_ / cellSize);
        maxCell = minCell + min(maxCell - minCell, maxSize);

        for (int x = minCell[0]; x <= maxCell[0]; ++x)
        {
            for (int y = minCell[1]; y <= maxCell[1]; ++y)
            {
                // for 2D mode, project each cell at z == 0 and check them too. This way we ensure 2D colliders
                // (which are inserted in cells with z == 0) are accounted for in the broadphase.
                if (mode == 1)
                {
                    uint flatCellIndex = GridHash(int4(x,y,0,l));
                    uint cellStart = cellOffsets[flatCellIndex];
                    uint cellCount = cellCounts[flatCellIndex];

                    // iterate through colliders in the cell
                    for (uint n = cellStart; n < cellStart + cellCount; ++n)
                    {
                        // append to the candidates list (sorted and deduplicated later);
                        // candidates found once the list is full are dropped.
                        if (candidateCount < MAX_CONTACTS_PER_DIFFUSE)
                            candidates[candidateCount++] = sortedColliderIndices[n] / cellsPerCollider;
                    }
                }

                for (int z = minCell[2]; z <= maxCell[2]; ++z)
                {
                    uint flatCellIndex = GridHash(int4(x,y,z,l));
                    uint cellStart = cellOffsets[flatCellIndex];
                    uint cellCount = cellCounts[flatCellIndex];

                    // iterate through colliders in the cell
                    for (uint n = cellStart; n < cellStart + cellCount; ++n)
                    {
                        if (candidateCount < MAX_CONTACTS_PER_DIFFUSE)
                            candidates[candidateCount++] = sortedColliderIndices[n] / cellsPerCollider;
                    }
                }
            }
        }
    }

    // evaluate candidates and resolve contacts:
    if (candidateCount > 0)
    {
        // insertion sort. The bounds check and the element comparison are kept in separate
        // statements: HLSL logical operators do not short-circuit, so a combined
        // "j >= 0 && candidates[j] > key" condition would read candidates[-1].
        for (int k = 1; k < candidateCount; ++k)
        {
            uint key = candidates[k];
            int j = k - 1;

            while (j >= 0)
            {
                if (candidates[j] <= key)
                    break;

                candidates[j + 1] = candidates[j];
                --j;
            }

            candidates[j + 1] = key;
        }

        // make sure each candidate only shows up once in the list
        // (the list is sorted, so duplicates are adjacent):
        int contactCount = 1;
        for (int r = 1; r < candidateCount; ++r)
        {
            if (candidates[r] != candidates[contactCount - 1])
                candidates[contactCount++] = candidates[r];
        }

        // solve contacts:
        for (int i = 0; i < contactCount; i++)
        {
            int c = candidates[i];

            // narrow down using the collider's world space bounds:
            aabb colliderBoundsWS = aabbs[c];
            if (!b.IntersectsAabb(colliderBoundsWS, mode == 1))
                continue;

            shape colliderShape = shapes[c];
            transform colliderToSolver = worldToSolver[0].Multiply(transforms[c]);

            switch(colliderShape.type)
            {
                case SPHERE_SHAPE:
                {
                    SurfacePoint surf;
                    Sphere sphereShape;
                    sphereShape.colliderToSolver = colliderToSolver;
                    sphereShape.s = colliderShape;
                    sphereShape.Evaluate(predPos, radii, QUATERNION_IDENTITY, surf);

                    PushDiffuseParticleOutOfSurface(predPos, surf, radius);
                }
                break;

                case BOX_SHAPE:
                {
                    SurfacePoint surf;
                    Box boxShape;
                    boxShape.colliderToSolver = colliderToSolver;
                    boxShape.s = colliderShape;
                    boxShape.Evaluate(predPos, radii, QUATERNION_IDENTITY, surf);

                    PushDiffuseParticleOutOfSurface(predPos, surf, radius);
                }
                break;

                case CAPSULE_SHAPE:
                {
                    SurfacePoint surf;
                    Capsule capShape;
                    capShape.colliderToSolver = colliderToSolver;
                    capShape.s = colliderShape;
                    capShape.Evaluate(predPos, radii, QUATERNION_IDENTITY, surf);

                    PushDiffuseParticleOutOfSurface(predPos, surf, radius);
                }
                break;

                case TRIANGLE_MESH_SHAPE:
                {
                    CollideMesh(c, threadIndex, b, predPos, radius);
                }
                break;

                case EDGE_MESH_SHAPE:
                {
                    CollideEdgeMesh(c, threadIndex, b, predPos, radius);
                }
                break;

                case HEIGHTMAP_SHAPE:
                {
                    CollideHeightmap(c, threadIndex, b, predPos, radius);
                }
                break;

                case SDF_SHAPE:
                {
                    SurfacePoint surf;
                    DistanceField dfShape;
                    dfShape.colliderToSolver = colliderToSolver;
                    dfShape.s = colliderShape;
                    dfShape.distanceFieldHeaders = distanceFieldHeaders;
                    dfShape.dfNodes = dfNodes;
                    dfShape.Evaluate(predPos, radii, QUATERNION_IDENTITY, surf);

                    PushDiffuseParticleOutOfSurface(predPos, surf, radius);
                }
                break;
            }
        }
    }

    // derive the velocity that takes the particle from its start position to the corrected
    // predicted position in one step (the w component is left untouched):
    inputVelocities[threadIndex].xyz = (predPos.xyz - startPos.xyz) / deltaTime;
}